{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# VGG16 Net Surgery\n",
    "VGG16 Transfer Learning After 3-to-4-Channel Input Conversion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Background\n",
    "The weakly supervised segmentation techniques presented in the \"Simple Does It\" paper use a backbone convnet (either a DeepLab or a VGG16 network) **pre-trained on ImageNet**. This pre-trained network takes RGB images as input (W x H x 3). The weakly supervised version, however, is trained on **4-channel inputs: RGB + a binary mask with a filled bounding box of the object instance**. Therefore, we need to **perform net surgery and create a 4-channel input version** of the VGG16 net. It is initialized with the 3-channel parameter values, **except** for the weights that the first convolutional layer applies to the additional input channel (we use Gaussian initialization for those).\n",
    "\n",
    "Here's what a typical VGG16 convnet looks like:\n",
    "\n",
    "![](img/vgg16.png)\n",
    "\n",
    "In the picture above, we're modifying the first convolutional layer (`conv1_1`) of the first block on the left."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "net_surgery.ipynb\n",
    "\n",
    "VGG16 Transfer Learning After 3-to-4-Channel Input Conversion\n",
    "\n",
    "Written by Phil Ferriere\n",
    "\n",
    "Licensed under the MIT License (see LICENSE for details)\n",
    "\n",
    "Based on:\n",
    "  - https://github.com/minhnhat93/tf_object_detection_multi_channels/blob/master/edit_checkpoint.py\n",
    "    Written by SNhat M. Nguyen\n",
    "    Unknown code license\n",
    "\"\"\"\n",
    "from tensorflow.python import pywrap_tensorflow\n",
    "import numpy as np\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_input_channels = 4 # AStream uses 4-channel inputs\n",
    "init_method = 'gaussian' # ['gaussian'|'spread_average'|'zeros']\n",
    "input_path = 'models/vgg_16_3chan.ckpt' # copy of checkpoint in http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz\n",
    "output_path = 'models/vgg_16_4chan.ckpt'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Surgery"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
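    "Here are the VGG16 `conv1_1` parameters. Only the weights need to change: their third dimension is the number of input channels, while the biases (one per output filter) are unaffected:\n",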
    "```\n",
    "(dlwin36tfwss) Phil@SERVERP E:\\repos\\tf-wss\\tfwss\\tools\n",
    "$  python -m inspect_checkpoint --file_name=../models/vgg_16_3chan.ckpt | grep -i conv1_1\n",
    "vgg_16/conv1/conv1_1/weights (DT_FLOAT) [3,3,3,64]\n",
    "vgg_16/conv1/conv1_1/biases (DT_FLOAT) [64]\n",
    "```\n",
    "First, let's find the correct tensor:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading checkpoint...\n",
      "...done loading checkpoint.\n",
      "Tensor vgg_16/conv1/conv1_1/weights of shape (3, 3, 3, 64) located.\n"
     ]
    }
   ],
   "source": [
    "print('Loading checkpoint...')\n",
    "reader = pywrap_tensorflow.NewCheckpointReader(input_path)\n",
    "print('...done loading checkpoint.')\n",
    "\n",
    "var_to_shape_map = reader.get_variable_to_shape_map()\n",
    "var_to_edit_name = 'vgg_16/conv1/conv1_1/weights'\n",
    "\n",
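    "for key in sorted(var_to_shape_map):\n",
    "    if key != var_to_edit_name:\n",
    "        # Recreate every other tensor unchanged so it gets saved to the new checkpoint\n",
    "        var = tf.Variable(reader.get_tensor(key), name=key, dtype=tf.float32)\n",
    "    else:\n",
    "        # Hold on to the conv1_1 weights; they are extended in the next cell\n",
    "        var_to_edit = reader.get_tensor(var_to_edit_name)\n",
    "        print('Tensor {} of shape {} located.'.format(var_to_edit_name, var_to_edit.shape))"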
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's edit the tensor and initialize it according to the chosen init method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
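    "if init_method != 'gaussian':\n",
    "    # Only Gaussian initialization is implemented; stop instead of silently falling through\n",
    "    raise NotImplementedError('Unimplemented initialization method: {}'.format(init_method))\n",
    "\n",
    "sess = tf.Session()\n",
    "\n",
    "# Shape of the extra input-channel slice(s): [3, 3, num_input_channels - 3, 64]\n",
    "new_channels_shape = list(var_to_edit.shape)\n",
    "new_channels_shape[2] = num_input_channels - 3\n",
    "\n",
    "# Draw small Gaussian weights for the new channel(s) and append them along the input-channel axis\n",
    "gaussian_var = tf.random_normal(shape=new_channels_shape, stddev=0.001).eval(session=sess)\n",
    "new_var = np.concatenate([var_to_edit, gaussian_var], axis=2)\n",
    "new_var = tf.Variable(new_var, name=var_to_edit_name, dtype=tf.float32)"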
   ]
  },
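  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The configuration cell also lists `'spread_average'` and `'zeros'` as init methods, but only `'gaussian'` is implemented above. The following NumPy-only sketch shows one way the three options could be handled. The helper name `extend_input_channels` is hypothetical, and the meaning given to `'spread_average'` (each new channel is the mean of the existing channels' filters) is an assumption, not something this notebook defines.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def extend_input_channels(kernel, num_input_channels, init_method='gaussian',\n",
    "                          stddev=0.001, seed=None):\n",
    "    # Hypothetical helper: pad a [H, W, C_in, C_out] conv kernel along the input-channel axis.\n",
    "    extra = num_input_channels - kernel.shape[2]\n",
    "    if extra <= 0:\n",
    "        raise ValueError('num_input_channels must exceed the current channel count')\n",
    "    new_shape = kernel.shape[:2] + (extra, kernel.shape[3])\n",
    "\n",
    "    if init_method == 'gaussian':\n",
    "        # Small random weights, as in the cell above\n",
    "        rng = np.random.RandomState(seed)\n",
    "        new_channels = rng.normal(0.0, stddev, size=new_shape)\n",
    "    elif init_method == 'zeros':\n",
    "        # The new channel initially contributes nothing to the layer output\n",
    "        new_channels = np.zeros(new_shape)\n",
    "    elif init_method == 'spread_average':\n",
    "        # ASSUMED meaning: copy the mean of the existing input-channel filters\n",
    "        mean = kernel.mean(axis=2, keepdims=True)\n",
    "        new_channels = np.repeat(mean, extra, axis=2)\n",
    "    else:\n",
    "        raise NotImplementedError('Unimplemented initialization method: {}'.format(init_method))\n",
    "\n",
    "    return np.concatenate([kernel, new_channels.astype(kernel.dtype)], axis=2)\n",
    "\n",
    "\n",
    "# Stand-in for the conv1_1 weights read from the checkpoint\n",
    "kernel = np.ones((3, 3, 3, 64), dtype=np.float32)\n",
    "for method in ('gaussian', 'zeros', 'spread_average'):\n",
    "    print(method, extend_input_channels(kernel, 4, method, seed=0).shape)\n",
    "```\n",
    "\n",
    "In every case the original three input-channel slices are left untouched, so the network's response to the RGB channels starts out identical to the pre-trained one."
   ]
  },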
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's update the network parameters and save the updated model to disk:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'models/vgg_16_4chan.ckpt'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sess.run(tf.global_variables_initializer())\n",
    "saver = tf.train.Saver()\n",
    "saver.save(sess, output_path)"
   ]
  },
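  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the checkpoint is saved, the surgery can also be sanity-checked numerically: the three RGB input-channel slices should be bit-for-bit identical to the original, and the extra channel should only hold small values. The sketch below expresses that check in plain NumPy on stand-in arrays. The function name `check_surgery` and the 6-sigma bound are illustrative assumptions; in practice the two arrays would come from `reader.get_tensor(var_to_edit_name)` on the 3-channel and 4-channel checkpoints.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def check_surgery(old_kernel, new_kernel, stddev=0.001):\n",
    "    # Hypothetical check: returns (rgb_preserved, extra_small) as two booleans.\n",
    "    n_old = old_kernel.shape[2]\n",
    "    # Only the input-channel axis (axis 2) is allowed to grow\n",
    "    assert new_kernel.shape[:2] == old_kernel.shape[:2]\n",
    "    assert new_kernel.shape[3] == old_kernel.shape[3]\n",
    "    assert new_kernel.shape[2] > n_old\n",
    "\n",
    "    # The pre-trained slices must be copied over exactly\n",
    "    rgb_preserved = np.array_equal(new_kernel[:, :, :n_old, :], old_kernel)\n",
    "    # The new slices should stay within a few standard deviations of zero (assumed bound: 6 sigma)\n",
    "    extra = new_kernel[:, :, n_old:, :]\n",
    "    extra_small = bool(np.all(np.abs(extra) < 6 * stddev))\n",
    "    return rgb_preserved, extra_small\n",
    "\n",
    "\n",
    "# Stand-in data with the same shapes as the conv1_1 weights\n",
    "rng = np.random.RandomState(0)\n",
    "old = rng.normal(0.0, 0.3, size=(3, 3, 3, 64)).astype(np.float32)\n",
    "extra = rng.normal(0.0, 0.001, size=(3, 3, 1, 64)).astype(np.float32)\n",
    "new = np.concatenate([old, extra], axis=2)\n",
    "print(check_surgery(old, new))\n",
    "```"
   ]
  },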
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Verification\n",
    "Verify the result of this surgery by looking at the output of the following commands:\n",
    "```\n",
    "$ python -m inspect_checkpoint --file_name=../models/vgg_16_3chan.ckpt --tensor_name=vgg_16/conv1/conv1_1/weights > vgg_16_3chan-conv1_1-weights.txt\n",
    "$ python -m inspect_checkpoint --file_name=../models/vgg_16_4chan.ckpt --tensor_name=vgg_16/conv1/conv1_1/weights > vgg_16_4chan-conv1_1-weights.txt\n",
    "```\n",
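    "You should see the following values in the first filter. The first three input-channel columns must match the 3-channel checkpoint exactly; the fourth column holds random Gaussian draws, so your values there will differ from the ones shown:\n",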
    "```\n",
    "# 3-channel VGG16\n",
    "3,3,3,0\n",
    "[[[ 0.4800154   0.55037946  0.42947057]\n",
    "  [ 0.4085474   0.44007453  0.373467  ]\n",
    "  [-0.06514555 -0.08138704 -0.06136011]]\n",
    "\n",
    " [[ 0.31047726  0.34573907  0.27476987]\n",
    "  [ 0.05020237  0.04063221  0.03868078]\n",
    "  [-0.40338343 -0.45350131 -0.36722335]]\n",
    "\n",
    " [[-0.05087169 -0.05863491 -0.05746817]\n",
    "  [-0.28522751 -0.33066967 -0.26224968]\n",
    "  [-0.41851634 -0.4850302  -0.35009676]]]\n",
    "  \n",
    "# 4-channel VGG16\n",
    "3,3,4,0\n",
    "[[[  4.80015397e-01   5.50379455e-01   4.29470569e-01   1.13388560e-04]\n",
    "  [  4.08547401e-01   4.40074533e-01   3.73466998e-01   7.61439209e-04]\n",
    "  [ -6.51455522e-02  -8.13870355e-02  -6.13601133e-02   4.74345696e-04]]\n",
    "\n",
    " [[  3.10477257e-01   3.45739067e-01   2.74769872e-01   4.11637186e-04]\n",
    "  [  5.02023660e-02   4.06322069e-02   3.86807770e-02   1.38304755e-03]\n",
    "  [ -4.03383434e-01  -4.53501314e-01  -3.67223352e-01   1.28411280e-03]]\n",
    "\n",
    " [[ -5.08716851e-02  -5.86349145e-02  -5.74681684e-02  -6.34787197e-04]\n",
    "  [ -2.85227507e-01  -3.30669671e-01  -2.62249678e-01  -1.77454809e-03]\n",
    "  [ -4.18516338e-01  -4.85030204e-01  -3.50096762e-01   2.10441509e-03]]]\n",
    "  \n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
